"""Add fields for prepopulated data from alerts

Revision ID: 67f1efb93c99
Revises: 9ba0aeecd4d0
Create Date: 2024-07-25 17:13:04.428633

"""
import warnings
import sqlalchemy as sa
from alembic import op
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session
from sqlalchemy import exc as sa_exc


# revision identifiers, used by Alembic.
# NOTE: down_revision is the parent Alembic links this migration to; keep the
# "Revises:" line in the module docstring in sync with it.
revision = "67f1efb93c99"
down_revision = "9ba0aeecd4d0"
branch_labels = None
depends_on = None

# Table definitions local to this migration. They live in their own MetaData
# (not the application's models) so the migration does not depend on the
# current model definitions. Only the columns this migration reads or writes
# are declared.
migration_metadata = sa.MetaData()

# Association table linking alerts to incidents.
alert_to_incident_table = sa.Table(
    'alerttoincident',
    migration_metadata,
    sa.Column('alert_id', UUID(as_uuid=False), sa.ForeignKey('alert.id', ondelete='CASCADE'), primary_key=True),
    sa.Column('incident_id', UUID(as_uuid=False), sa.ForeignKey('incident.id', ondelete='CASCADE'), primary_key=True)
)

# Incident table, including the three summary columns added by upgrade().
# No Python-side defaults are declared: this migration only issues UPDATEs that
# set every one of these columns explicitly. (Note that ``default_factory`` is
# not an ``sa.Column`` argument; SQLAlchemy would treat it as an option for a
# non-existent "default" dialect and emit an SAWarning.)
incident_table = sa.Table(
    'incident',
    migration_metadata,
    sa.Column('id', UUID(as_uuid=False), primary_key=True),
    sa.Column('alerts_count', sa.Integer),
    sa.Column('affected_services', sa.JSON),
    sa.Column('sources', sa.JSON)
)

# Alert table; only the id (for the join) and the raw event payload are used.
alert_table = sa.Table(
    'alert',
    migration_metadata,
    sa.Column('id', UUID(as_uuid=False), primary_key=True),
    sa.Column('provider_type', sa.String),
    sa.Column('event', sa.JSON)
)


class AlertDtoLocal(BaseModel):
    """Minimal view of an alert's ``event`` JSON payload used by this migration.

    Only the two fields needed to backfill the incident summary columns are
    declared; other keys in the payload are ignored under pydantic's default
    ``extra`` handling. Presumably defined locally (rather than importing the
    application's alert DTO) so the migration is insulated from later model
    changes — confirm.
    """

    # Name of the service the alert refers to; None when the key is absent.
    service: str | None = None
    # Source identifiers reported by the alert. Defaults to an empty list when
    # the key is absent (pydantic copies mutable defaults per instance), but can
    # still be None if the payload explicitly contains "source": null.
    source: list[str] | None = []


def _summarize_alert_events(events: list[dict | None]) -> dict:
    """Return the incident summary column values derived from alert ``event`` payloads.

    Args:
        events: The ``event`` JSON of every alert linked to one incident. A null
            payload is treated as an empty one, so the alert still counts but
            contributes no sources or services.

    Returns:
        A dict suitable for ``update(...).values(**...)`` with keys ``sources``
        (sorted, de-duplicated union of every alert's ``source`` list),
        ``affected_services`` (sorted, de-duplicated non-null ``service``
        values) and ``alerts_count`` (number of linked alerts).
    """
    alerts_dto = [AlertDtoLocal(**(event or {})) for event in events]
    # ``source`` is declared Optional, so guard against an explicit null.
    sources = {source for dto in alerts_dto for source in (dto.source or ())}
    services = {dto.service for dto in alerts_dto if dto.service is not None}
    return {
        # Sorted so the stored JSON arrays are deterministic.
        "sources": sorted(sources),
        "affected_services": sorted(services),
        "alerts_count": len(events),
    }


def populate_db():
    """Backfill ``sources``, ``affected_services`` and ``alerts_count`` on every incident.

    For each incident, loads the ``event`` payload of every alert linked via
    ``alerttoincident`` and writes the aggregated values (see
    ``_summarize_alert_events``). Runs on the migration's bound connection and
    commits the session once all incidents are updated.
    """
    with Session(op.get_bind()) as session:
        incident_ids = session.execute(sa.select(incident_table.c.id)).scalars().all()

        for incident_id in incident_ids:
            events_stmt = (
                sa.select(alert_table.c.event)
                .select_from(alert_table)
                .join(alert_to_incident_table, alert_table.c.id == alert_to_incident_table.c.alert_id)
                .where(alert_to_incident_table.c.incident_id == str(incident_id))
            )
            events = session.execute(events_stmt).scalars().all()

            session.execute(
                sa.update(incident_table)
                .where(incident_table.c.id == incident_id)
                .values(**_summarize_alert_events(events))
            )
        session.commit()


def upgrade() -> None:
    """Add the incident summary columns, then backfill them from linked alerts."""
    # ### commands auto generated by Alembic - please adjust! ###
    new_incident_columns = (
        sa.Column("affected_services", sa.JSON(), nullable=True),
        sa.Column("sources", sa.JSON(), nullable=True),
        # server_default gives existing rows a value so the NOT NULL column can be added.
        sa.Column("alerts_count", sa.Integer(), nullable=False, server_default="0"),
    )
    for column in new_incident_columns:
        op.add_column("incident", column)

    populate_db()
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the incident summary columns, in reverse order of their creation."""
    # ### commands auto generated by Alembic - please adjust! ###
    for column_name in ("alerts_count", "sources", "affected_services"):
        op.drop_column("incident", column_name)
    # ### end Alembic commands ###
